#------------------------------------------------------------------------------
# plot trends over time
# Figure 1. Pain diagnosis
# Figure S1. Enrollment 
# Figure 2. Trends among pain patients
#==============================================================================

load_library = c('bit64','data.table','fst','future.apply','stringr','logger','vroom')
invisible(lapply(load_library, function(x) library(x, character.only=TRUE, quietly= TRUE)))

library(ggplot2)
library(ggpubr)

# Normal-approximation confidence bounds for a weighted mean.
#
# `w` is treated as frequency weights, and the bounds are
#   wmean -/+ z * wsd / sqrt(N),   N = sum(w),
# where wsd is the unbiased frequency-weighted standard deviation
#   sqrt(sum(w * (x - wmean)^2) / (N - 1))
# (the estimator Hmisc::wtd.var() uses with its default arguments).
#
# (x, w) pairs with a missing value in either are dropped before the mean,
# the SD *and* N are computed. N then only counts the weight of observations
# that actually enter the mean; counting the weight of NA rows would shrink
# the standard error and understate the interval width.
#
# Args:
#   x: numeric values.
#   w: non-negative weights, same length as x.
#   z: normal quantile; the default 1.96 gives a two-sided 95% interval.
# Returns: named numeric c(lower = , upper = ). Not meaningful when the
#   remaining total weight N is <= 1 (the variance is undefined).
weighted_mean_ci = function(x, w, z = 1.96) {
	stopifnot(length(x) == length(w))
	# combine in double precision even if counts arrive as integer types
	x = as.numeric(x)
	w = as.numeric(w)
	keep = !is.na(x) & !is.na(w)
	x = x[keep]
	w = w[keep]
	N = sum(w)
	wmean = weighted.mean(x, w)
	wsd = sqrt(sum(w * (x - wmean)^2) / (N - 1))
	half_width = z * wsd / sqrt(N)
	c(lower = wmean - half_width, upper = wmean + half_width)
}

# Lower bound of the weighted-mean confidence interval (see weighted_mean_ci()).
lci = function(x, w, z = 1.96) {
	unname(weighted_mean_ci(x, w, z)['lower'])
}

# Upper bound of the weighted-mean confidence interval (see weighted_mean_ci()).
uci = function(x, w, z = 1.96) {
	unname(weighted_mean_ci(x, w, z)['upper'])
}

# Root of the replication package. It defaults to the original author's
# Dropbox path. Set the OP_DATAVERSE_DIR environment variable to point the
# script at a copy elsewhere without editing this file.
default_bucket = '/Users/bk/Dropbox/project/OP/- posted/2021-06-08_covid_substitution_JAMA/dataverse'
bucket = Sys.getenv('OP_DATAVERSE_DIR', unset = '')
# an empty value is treated the same as "not set"
if (!nzchar(bucket)) bucket = default_bucket
data_path = file.path(bucket, 'data')

# input panel (fst) read below; figures are also written under data_path
infile = file.path(data_path,'processed_data','ses_division_weekly_covid_0101_0930.fst')

# ---------- load the weekly panel and derive calendar fields from `date`
logger::log_info('now loading data ...', infile)
dt = read_fst(infile, as.data.table = TRUE)

# `date` is split on the literal 'W' into year and week parts; the split
# already removes the separator, so `week` only needs a numeric conversion
dt[, c('year', 'week') := tstrsplit(date, 'W', fixed = TRUE)]
dt[, week := as.numeric(week)]

# Replace `date` with the Monday (%u = 1) of that week under R's %U numbering.
# NOTE(review): assumes the source week numbers follow the %U (Sunday-start)
# convention rather than ISO 8601 weeks -- confirm against the data pipeline.
dt[, date := as.Date(paste0(year, '-', week, '-1'), format = '%Y-%U-%u')]
dt[, month := month(date)]

# Coerce the opioid quantity/days/MME aggregates of every pre-condition group
# to double.
# NOTE(review): presumably these are read as integer64 (bit64 is loaded
# above) -- confirm.
# A single `:=` over .SDcols converts all columns in place. Assigning each
# column with `dt[[col]] = ...` goes through the data.frame `[[<-` method and
# copies the table every time.
pre_condition_list = c('all','backpain','pain','ord')
opioid_stats = c('sum_sum_opioid_quantity','sum_sum_opioid_days','sum_sum_opioid_mme',
	'mean_sum_opioid_quantity','mean_sum_opioid_days','mean_sum_opioid_mme')
opioid_cols = unlist(lapply(pre_condition_list, paste, opioid_stats, sep='_'))
dt[, (opioid_cols) := lapply(.SD, as.numeric), .SDcols = opioid_cols]

# drop the first week: keep only rows whose week number exceeds 1
dt = dt[week > 1]

# Outcomes plotted in Figure 2. Outcome i is read from column
# `targetpain_mean_<outcome>` and plotted with label_target_var[i] and
# ylim_all[[i]]; the lookups are by position, not by name.
select_var = c('opioids','therapy','sum_opioid_days','sum_opioid_mme')

target_var = paste0('targetpain_mean_',select_var)
label_target_var = c(paste0('P(', c('Opioids','Therapy'), ' Takers) (%)'),
					'Days of RX','Total MME')
ylim_all = list(
	opioid=c(0,15),
	therapy=c(0,40),
	daysrx=c(15,25),
	mme=c(40,50))

# positional lookups must line up, otherwise a plot silently gets the wrong
# label or axis range
stopifnot(length(label_target_var) == length(select_var),
	length(ylim_all) == length(select_var))

# Figure 2. For each outcome in `select_var`, plot the weekly N-weighted mean
# of `targetpain_mean_<outcome>` with its confidence ribbon (lci/uci), 2019 vs
# 2020 on a shared 2020 calendar axis. Each plot is saved as
# figure_2_<outcome>_trends.eps under `data_path`.
#
# ggsave()'s default PostScript device does not support semi-transparency and
# omits semi-transparent objects, which would silently drop the alpha = 0.3 CI
# ribbons from the EPS files. Use the cairo PostScript device when R has
# cairo support; otherwise fall back to the plain 'eps' device.
eps_device = if (isTRUE(capabilities('cairo'))) grDevices::cairo_ps else 'eps'

for (var in select_var){
	mean_target_var = paste0('targetpain_mean_',var)
	lci_col = paste0('lci_',mean_target_var)
	uci_col = paste0('uci_',mean_target_var)

	# weekly weighted mean and CI bounds (weights = N) in one grouped pass,
	# sorted by date
	mean_dt = dt[, {
		x = .SD[[1L]]
		list(weighted.mean(x, N, na.rm=TRUE), lci(x, N), uci(x, N))
	}, keyby='date', .SDcols=mean_target_var]
	setnames(mean_dt, c('date', mean_target_var, lci_col, uci_col))

	mean_dt[, date := as.Date(date)]
	mean_dt[, year := factor(year(date), levels=c(2019,2020))]
	# overlay 2019 onto the 2020 calendar so both years share one x axis
	mean_dt[, date := as.Date(gsub('2019','2020',date))]

	file_name = file.path(data_path,paste0('figure_2_',var,'_trends.eps'))
	label_y = label_target_var[match(var,select_var)]
	# named y_limits (not ylim) to avoid shadowing ggplot2::ylim()
	y_limits = ylim_all[[match(var,select_var)]]

	# probability outcomes are displayed as percentages
	if (grepl('P\\(',label_y)) {
		pct_cols = c(mean_target_var, lci_col, uci_col)
		mean_dt[, (pct_cols) := lapply(.SD, `*`, 100), .SDcols=pct_cols]
	}

	p = ggplot(mean_dt[date <= as.Date('2020-09-30')],
			aes(x=date, y=.data[[mean_target_var]],
				ymin=.data[[lci_col]], ymax=.data[[uci_col]],
				group=year, color=year, fill=year)) +
		geom_line() +
		geom_ribbon(alpha=0.3) +
		geom_point() +
		scale_x_date(date_breaks = '1 month', date_labels = '%b') +
		scale_color_manual(name='year', values = c('2019'='black','2020'='red')) +
		scale_fill_manual(name='year', values = c('2019'='gray','2020'='pink')) +
		geom_vline(xintercept=as.Date('2020-03-13'), col='blue', lwd=0.5) +
		geom_vline(xintercept=as.Date('2020-07-04'), col='black', lwd=0.5, linetype='dashed') +
		theme_bw() +
		theme(legend.position = 'top') +
		labs(y=label_y, x=NULL) +
		# zoom instead of setting scale limits: ylim() would turn points and CI
		# bounds outside the range into NA and drop them from the figure
		coord_cartesian(ylim=y_limits)

	ggsave(file_name, p, width=5, height=4, device=eps_device)
}


